{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0939e50a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start Milvus server\n",
    "!curl https://raw.githubusercontent.com/milvus-io/milvus/master/deployments/docker/standalone/docker-compose.yml > docker-compose.yml\n",
    "!docker-compose down\n",
    "!docker-compose up -d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b15ba381",
   "metadata": {},
   "outputs": [],
   "source": [
    "# install dependencies\n",
    "!python3 -m pip install -r requirements.jupyter.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e740a62",
   "metadata": {},
   "outputs": [],
   "source": [
    "from notebook_config import UPLOAD_PATH, SEARCH_FEATURE_PATH, LOAD_FEATURE_PATH, METRIC_TYPE, MAX_FACES, NUM_KERNEL,SIGMA,AGGREGATION_METHOD,WEIGHTS,CUDA_DEVICE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c499d61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as data\n",
    "\n",
    "from milvus3d.MeshNet import MeshNet\n",
    "from milvus3d.transform import Transformer\n",
    "from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility\n",
    "\n",
    "import subprocess\n",
    "import os\n",
    "import csv\n",
    "\n",
    "from collections import deque\n",
    "import functools\n",
    "import warnings\n",
    "from IPython.display import Image\n",
    "# warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34c6c8cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy import linalg as LA\n",
    "\n",
    "class Encode:\n",
    "    \"\"\"\n",
    "    Create embedding vector for a 3d model\n",
    "\n",
    "    Input <str>: path of the preprocessed 3d model in npy format\n",
    "    Output <List>: normalized embedding vector for that 3d model\n",
    "    \"\"\"\n",
    "\n",
    "    def do_extract(self, path, transformer):\n",
    "        data = self.prepare(path)\n",
    "        return self.extract_fea(transformer, *data)\n",
    "\n",
    "    def extract_fea(self, transformer, centers, corners, normals, neighbor_index):\n",
    "        if torch.cuda.is_available():\n",
    "            centers = Variable(torch.cuda.FloatTensor(centers.cuda()))\n",
    "            corners = Variable(torch.cuda.FloatTensor(corners.cuda()))\n",
    "            normals = Variable(torch.cuda.FloatTensor(normals.cuda()))\n",
    "            neighbor_index = Variable(torch.cuda.LongTensor(neighbor_index.cuda()))\n",
    "        else:\n",
    "            centers = Variable(torch.FloatTensor(centers.cpu()))\n",
    "            corners = Variable(torch.FloatTensor(corners.cpu()))\n",
    "            normals = Variable(torch.FloatTensor(normals.cpu()))\n",
    "            neighbor_index = Variable(torch.LongTensor(neighbor_index.cpu()))\n",
    "        # get vectors\n",
    "        feat = list(transformer.get_vector(centers, corners, normals, neighbor_index).tolist())\n",
    "        return feat / LA.norm(feat)\n",
    "\n",
    "    def prepare(self, path):\n",
    "        data = np.load(path)\n",
    "        face = data['faces']\n",
    "        neighbor_index = data['neighbors']\n",
    "\n",
    "        # fill for n < max_faces with randomly picked faces\n",
    "        num_point = len(face)\n",
    "        if num_point < 1024:\n",
    "            fill_face = []\n",
    "            fill_neighbor_index = []\n",
    "            for i in range(MAX_FACES - num_point):\n",
    "                index = np.random.randint(0, num_point)\n",
    "                fill_face.append(face[index])\n",
    "                fill_neighbor_index.append(neighbor_index[index])\n",
    "            face = np.concatenate((face, np.array(fill_face)))\n",
    "            neighbor_index = np.concatenate((neighbor_index, np.array(fill_neighbor_index)))\n",
    "\n",
    "        # to tensor\n",
    "        face = torch.from_numpy(face).float()\n",
    "        neighbor_index = torch.from_numpy(neighbor_index).long()\n",
    "\n",
    "        # reorganize\n",
    "        face = face.permute(1, 0).contiguous()\n",
    "        centers, corners, normals = face[:3], face[3:12], face[12:]\n",
    "        corners = corners - torch.cat([centers, centers, centers], 0)\n",
    "\n",
    "        return centers[np.newaxis, :, :], corners[np.newaxis, :, :], normals[np.newaxis, :, :], neighbor_index[np.newaxis, :, :]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab0bd1ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_models(path):\n",
    "    models = []\n",
    "    for f in os.listdir(path):\n",
    "        if ((f.endswith(extension) for extension in\n",
    "             ['npy']) and not f.startswith('.DS_Store')):\n",
    "            models.append(os.path.join(path, f))\n",
    "    return models\n",
    "\n",
    "def extract_features(model_dir, transformer):\n",
    "    feats = []\n",
    "    names = []\n",
    "    model_list = get_models(model_dir)\n",
    "\n",
    "\n",
    "    total = len(model_list)\n",
    "    model = Encode()\n",
    "    for i, model_path in enumerate(model_list):\n",
    "        if i%1001 == 1000:\n",
    "            print(f\"Extracting features: {i} out of {total}\")\n",
    "        # create embedding for model\n",
    "        norm_feat = model.do_extract(model_path, transformer)\n",
    "        feats.append(norm_feat.tolist())\n",
    "        names.append(model_path.encode())\n",
    "\n",
    "    return feats, names\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e93c54d",
   "metadata": {},
   "source": [
    "## Download Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbc516b4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install gdown\n",
    "import gdown\n",
    "\n",
    "\n",
    "target_dir = UPLOAD_PATH + \".tar.gz\"\n",
    "if UPLOAD_PATH == \"test_data\":\n",
    "    data_gdown_path = \"https://drive.google.com/uc?id=1m0fRU6RZG1zi2cZIDpAp8a1uOpAs9Wi-\"\n",
    "elif UPLOAD_PATH == \"ModelNet40\":\n",
    "    data_gdown_path = \"https://drive.google.com/uc?id=1iJNcFliFL7zEmroBHR0iH0a40lVQ8pDR\"\n",
    "    \n",
    "if not os.path.exists('data/' + UPLOAD_PATH):\n",
    "    !gdown \"{data_gdown_path}\" -O {target_dir}\n",
    "    !tar -xf {target_dir} -C data/\n",
    "    !rm {target_dir}\n",
    "    \n",
    "# preprocess\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d170a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run this if you wish to preprocess it locally\n",
    "# !cd data && ./preprocess.sh true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2c9f2f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "if UPLOAD_PATH == \"ModelNet40\":\n",
    "    if not os.path.exists('data/load_feature'):\n",
    "\n",
    "        !gdown \"https://drive.google.com/uc?id=1XFonx5ubCSTzEQGvGkpX5LXgdAK3yHQX\" -O data/load_feature.tar.gz\n",
    "        !tar -xf data/load_feature.tar.gz -C data\n",
    "        !rm data/load_feature.tar.gz\n",
    "        \n",
    "elif UPLOAD_PATH == 'test_data':\n",
    "    !cd data && ./preprocess.sh true"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74392f02",
   "metadata": {},
   "source": [
    "## Load DL Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88a81734",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "weights_dir = \"models/\"\n",
    "weights = \"MeshNet_best_9192.pkl\"\n",
    "if not os.path.exists(weights_dir):\n",
    "    os.mkdir(weights_dir)\n",
    "if not os.path.exists(weights_dir + weights):\n",
    "    !gdown \"https://drive.google.com/uc?id=1t5jyJ4Ktmlck6GYhNTPVTFZuRP7wPUYq\" -O {weights_dir+weights}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d769771",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_DEVICE\n",
    "\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "model = MeshNet(require_fea=False)\n",
    "model = nn.DataParallel(model)\n",
    "\n",
    "model.load_state_dict(torch.load(WEIGHTS, map_location=device))\n",
    "model.to(device)\n",
    "\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19f307e3",
   "metadata": {},
   "source": [
    "## Create Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b9f4bd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer = Transformer(model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0908d98",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if not os.path.exists(\"data/vectors.txt\"):\n",
    "    vectors, names = extract_features('data/' + LOAD_FEATURE_PATH, transformer)\n",
    "    np.savetxt(\"data/vectors.txt\", np.array(vectors), delimiter=',')\n",
    "    np.savetxt(\"data/names.txt\",np.array(names), delimiter=',', fmt=\"%s\")\n",
    "else:\n",
    "    \n",
    "    with open(\"data/names.txt\", newline='') as f:\n",
    "        reader = csv.reader(f)\n",
    "        names = list(reader)\n",
    "        names = [i[0] for i in names]\n",
    "    \n",
    "    vectors = np.genfromtxt(\"data/vectors.txt\",delimiter=',').tolist()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33f25c1a",
   "metadata": {},
   "source": [
    "## Connect to Milvus Server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36b6eef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "connections.connect(host=\"localhost\", port=19530)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3f95219",
   "metadata": {},
   "source": [
    "## Create Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec175829",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete the collection if it exists\n",
    "collection_name = \"mesh_similarity_search\"\n",
    "if utility.has_collection(collection_name):\n",
    "    collection = Collection(name=collection_name)\n",
    "    collection.drop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abd2ef05",
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = 256\n",
    "default_fields = [\n",
    "    FieldSchema(name=\"id\", dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
    "    FieldSchema(name=\"vector\", dtype=DataType.FLOAT_VECTOR, dim=dim)\n",
    "]\n",
    "default_schema = CollectionSchema(fields=default_fields, description=\"3d model test collection\")\n",
    "\n",
    "collection = Collection(name=collection_name, schema=default_schema)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c232d026",
   "metadata": {},
   "source": [
    "## Insert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f56ee747",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time \n",
    "t = time.time()\n",
    "\n",
    "mr = collection.insert([[i for i in vectors]])\n",
    "\n",
    "t = time.time() - t\n",
    "ids = mr.primary_keys\n",
    "\n",
    "print(f'Inserting {len(ids)} vectors took {t} seconds.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f022e94d",
   "metadata": {},
   "source": [
    "## Create Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f6f9d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "default_index = {\"index_type\": \"IVF_SQ8\", \"metric_type\": METRIC_TYPE, \"params\": {\"nlist\": 16384}}\n",
    "t = time.time()\n",
    "status = collection.create_index(field_name=\"vector\", index_params=default_index)\n",
    "t = time.time() - t\n",
    "print(f'Creating index for {len(ids)} vectors took {t} seconds.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0714a62",
   "metadata": {},
   "source": [
    "## Load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "010a1e9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.load()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03d39697",
   "metadata": {},
   "source": [
    "## Store mapping\n",
    "Here we need to store mapping that maps Milvus_id returned by Milvus vector database to Embedding vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa89c237",
   "metadata": {},
   "outputs": [],
   "source": [
    "# store {milvus_id: embedded vector} in a python dictionary\n",
    "milvus_to_vector = {}\n",
    "for i in range(len(names)):\n",
    "    milvus_to_vector[ids[i]] = vectors[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ef8d6cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# record milvus id\n",
    "def store_milvus_id(ids, root):\n",
    "    result = {}\n",
    "    d_list = deque(ids)\n",
    "    data_root = 'data/' + LOAD_FEATURE_PATH\n",
    "    for filename in os.listdir(data_root):\n",
    "        if ((f.endswith(extension) for extension in\n",
    "             ['npy']) and not filename.startswith('.DS_Store')):\n",
    "            result[d_list[0]] = root+'/' +filename.split('.')[0]+ \".off\"\n",
    "            d_list.popleft()\n",
    "\n",
    "    assert not d_list\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce42c115",
   "metadata": {},
   "outputs": [],
   "source": [
    "milvus_to_filename = store_milvus_id(ids, 'data/' + UPLOAD_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d21249d",
   "metadata": {},
   "source": [
    "## Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "620c70de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select a 3d model\n",
    "search_model_path = UPLOAD_PATH + \"/toilet_0001.off\"\n",
    "search_filename = search_model_path.split('/')[-1]\n",
    "search_path = '/'.join(search_model_path.split('/')[:-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1cf88db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess\n",
    "!cd data && python3 compress.py --batch \"F\" --filename {search_filename} --path {search_path}\n",
    "!docker run -it --rm -v `pwd`:/data pymesh/pymesh /bin/bash -c \"cd /data/data && python preprocess_npy.py --batch 'F' --filename {search_filename}\"\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "961187ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = Encode()\n",
    "feat = encoder.do_extract(os.path.join('data/'+SEARCH_FEATURE_PATH, search_filename.replace(\"off\",\"npz\")), transformer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a383516d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "search_params = {\"metric_type\": METRIC_TYPE, \"params\": {\"nprobe\": 16}}\n",
    "t = time.time()\n",
    "res = collection.search([feat.tolist()], anns_field=\"vector\", param=search_params, limit=3)\n",
    "t = time.time() - t\n",
    "print(f'Searching one vector in a vector database that has {len(ids)} vectors took {t} seconds.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11dc44f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse results\n",
    "vids = [x.id for x in res[0]]\n",
    "paths = [milvus_to_filename[vids[i]] for i in range(len(vids))]\n",
    "distances = [x.distance for x in res[0]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17a9ab33",
   "metadata": {},
   "source": [
    "## Display results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58eec941",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pygel3d import hmesh, gl_display as gl\n",
    "from pygel3d import jupyter_display as jd\n",
    "\n",
    "model = hmesh.load('data/'+search_model_path)\n",
    "print(\"filename: \" + search_model_path.split('/')[-1])\n",
    "jd.set_export_mode(True)\n",
    "jd.display(model, smooth=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c6b3cd8",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# display returned 3d models\n",
    "for return_path in paths[1:]:\n",
    "    model = hmesh.load(return_path)\n",
    "    print(\"filename: \" + return_path.split('/')[-1])\n",
    "    jd.set_export_mode(True)\n",
    "    jd.display(model, smooth=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bba582b",
   "metadata": {},
   "source": [
    "## Demo for the model compression process\n",
    "This solution compresses the models with n faces to 1024 faces to:\n",
    "1. Save computatioal resources\n",
    "2. Let the DL model focus more on the structure rather than detailed features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a374ad30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select a 3d model\n",
    "search_model_path = UPLOAD_PATH + \"/airplane_0001.off\"\n",
    "search_filename = search_model_path.split('/')[-1]\n",
    "search_path = '/'.join(search_model_path.split('/')[:-1])\n",
    "\n",
    "# Compression\n",
    "!cd data && python3 compress.py --batch \"F\" --filename {search_filename} --path {search_path}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87fce62c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pygel3d import hmesh, gl_display as gl\n",
    "from pygel3d import jupyter_display as jd\n",
    "\n",
    "model = hmesh.load('data/' + search_model_path)\n",
    "print(\"This is the original 3d Model\")\n",
    "print(\"filename: \" + search_model_path.split('/')[-1])\n",
    "jd.set_export_mode(True)\n",
    "jd.display(model, smooth=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fa23b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"This is the 3d Model after compression.\")\n",
    "model = hmesh.load('data/'+search_model_path.replace(UPLOAD_PATH,'search_feature'))\n",
    "print(\"filename: \" + search_model_path.split('/')[-1])\n",
    "jd.set_export_mode(True)\n",
    "jd.display(model, smooth=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
